package com.cloud.tool.util;

import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.ruijie.dashboard.entity.master.BudgetBase;

import java.lang.reflect.Field;

/**
 * Developer tool that prints boilerplate for MyBatis-Plus batch inserts of one entity:
 * the mapper method signature, the matching XML {@code <insert>} statement, a plain parallel
 * insert snippet, and a parallel insert snippet with per-thread transactions.
 *
 * <p>Everything is written to {@code System.out} for copy/paste; nothing is executed.
 */
public class BatchInsertUtil {
    /** Mapper method name used by the legacy overload of {@link #printSave(String, String, String, int)}. */
    private static final String DEFAULT_METHOD_NAME = "batchSchedule";
    /** Loop variable name used inside the generated {@code <foreach>}. */
    private static final String ITEM_NAME = "res";
    /** Number of columns per line in the generated INSERT column/value lists. */
    private static final int COLUMNS_PER_LINE = 20;

    public static void main(String[] args) {
        // Batch method name; must match the mapper method and the XML statement id.
        String methodName = "batchSchedule";
        // @Param name of the list argument in the mapper method.
        String paramName = "addList";
        // Name of the service field used in the calling code.
        String serviceName = "baseInfoService";
        Class<?> entity = BudgetBase.class;
        // Batch insert: mapper signature and XML statement.
        printMapper(entity.getSimpleName(), methodName, paramName);
        printXml(entity, methodName, paramName);
        // Plain parallel batch insert, no transaction.
        printSave(entity.getSimpleName(), serviceName, methodName, paramName, 1000);
        // Parallel insert with one transaction per thread, committed together; use with care.
        printAddTransaction(entity.getSimpleName(), serviceName, methodName, paramName, 1000);
    }

    /**
     * Prints a service method that inserts each partition of {@code paramName} in its own thread
     * and transaction, and commits every partition only if all of them inserted successfully
     * (otherwise every opened transaction is rolled back).
     *
     * <p>Each worker counts down a shared latch in {@code finally}, so a failing partition can no
     * longer leave the others waiting forever. Limitations of the generated code that remain:
     * every worker blocks until all partitions have inserted, so the executor must be able to run
     * all partitions at once (fewer threads than partitions blocks forever); and commits are issued
     * one by one, so a failure during the commit phase can still leave partial data.
     *
     * @param entity      simple name of the entity class
     * @param serviceName field name of the service whose {@code getBaseMapper()} is called
     * @param methodName  mapper batch method name; also used as the generated service method name
     * @param paramName   name of the list parameter
     * @param size        number of rows per partition
     */
    private static void printAddTransaction(String entity, String serviceName, String methodName,
                                            String paramName, int size) {
        System.out.println("-----------需要注入Bean，一个是Spring Boot事务管理，一个是线程池-----------");
        System.out.println(joinLines(
                "@Autowired",
                "private PlatformTransactionManager transactionManager;",
                "@Autowired",
                "@Qualifier(\"ioForkJoinPool\")",
                "private ForkJoinPool ioForkJoinPool;"));
        System.out.println("-----------多线程事务新增操作-----------");
        System.out.println(joinLines(
                "private void " + methodName + "(List<" + entity + "> " + paramName + ") {",
                "    if (CollectionUtils.isEmpty(" + paramName + ")) {",
                "        return;",
                "    }",
                "    //切分新增集合",
                "    List<List<" + entity + ">> partition = Lists.partition(" + paramName + ", " + size + ");",
                "    int totalSize = partition.size();",
                "    //定义局部变量，是否成功、顺序标识、全部分片插入完成的闭锁",
                "    AtomicBoolean isSuccess = new AtomicBoolean(true);",
                "    AtomicInteger cur = new AtomicInteger(1);",
                "    CountDownLatch insertedLatch = new CountDownLatch(totalSize);",
                "    //多线程处理开始；线程池可用线程数必须不少于分片数，否则会一直等待",
                "    CompletableFuture<?>[] futures = partition.stream().map(addPartitionList ->"
                        + " CompletableFuture.runAsync(() -> {",
                "        int curInt = cur.getAndIncrement();",
                "        TransactionStatus statusGo = null;",
                "        try {",
                "            //Spring事务内部由ThreadLocal存储事务绑定信息，因此需要每个线程新开一个事务",
                "            DefaultTransactionDefinition defGo = new DefaultTransactionDefinition();",
                "            defGo.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);",
                "            statusGo = transactionManager.getTransaction(defGo);",
                "            log.info(\"当前是第{}个线程开始启动，线程名={}\", curInt, Thread.currentThread().getName());",
                "            " + serviceName + ".getBaseMapper()." + methodName + "(addPartitionList);",
                "            log.info(\"当前是第{}个线程完成批量插入，开始等待其他线程，线程名={}\", curInt,"
                        + " Thread.currentThread().getName());",
                "        } catch (Exception e) {",
                "            log.error(\"当前是第{}个线程出现异常，线程名={}\", curInt, Thread.currentThread().getName(), e);",
                "            isSuccess.set(false);",
                "        } finally {",
                "            //无论成功失败都必须计数，否则其他线程会永久等待",
                "            insertedLatch.countDown();",
                "        }",
                "        try {",
                "            insertedLatch.await();",
                "        } catch (InterruptedException e) {",
                "            Thread.currentThread().interrupt();",
                "            isSuccess.set(false);",
                "        }",
                "        if (statusGo == null) {",
                "            return;",
                "        }",
                "        if (isSuccess.get()) {",
                "            log.info(\"当前是第{}个线程提交，线程名={}\", curInt, Thread.currentThread().getName());",
                "            transactionManager.commit(statusGo);",
                "        } else {",
                "            log.info(\"当前是第{}个线程回滚，线程名={}\", curInt, Thread.currentThread().getName());",
                "            transactionManager.rollback(statusGo);",
                "        }",
                "    }, ioForkJoinPool)).toArray(CompletableFuture[]::new);",
                "    CompletableFuture.allOf(futures).join();",
                "}"));
    }

    /**
     * Prints a parallel, non-transactional insert snippet that calls
     * {@code serviceName.getBaseMapper().batchSchedule(...)}.
     *
     * <p>Kept for existing callers; equivalent to
     * {@link #printSave(String, String, String, String, int)} with method name {@code batchSchedule}.
     */
    public static void printSave(String entity, String serviceName, String paramName, int size) {
        printSave(entity, serviceName, DEFAULT_METHOD_NAME, paramName, size);
    }

    /**
     * Prints a snippet that splits {@code paramName} into partitions of {@code size} rows and
     * inserts them in parallel without a transaction.
     *
     * <p>The generated snippet does not join the combined future, so the calling method returns
     * before the inserts finish; failures are only logged by {@code exceptionally}. The snippet's
     * {@code runAsync} calls pass no executor, so they run on {@code ForkJoinPool.commonPool()}.
     *
     * @param entity      simple name of the entity class
     * @param serviceName field name of the service whose {@code getBaseMapper()} is called
     * @param methodName  mapper batch method name
     * @param paramName   name of the list variable being inserted
     * @param size        number of rows per partition
     */
    public static void printSave(String entity, String serviceName, String methodName,
                                 String paramName, int size) {
        System.out.println("-----------需要引入包-----------");
        System.out.println("import com.google.common.collect.Lists;");
        System.out.println("-----------并行操作-----------");
        System.out.println(joinLines(
                "if (CollectionUtils.isNotEmpty(" + paramName + ")) {",
                "    List<List<" + entity + ">> partition = Lists.partition(" + paramName + ", " + size + ");",
                "    CompletableFuture.allOf(partition.stream().map(addPartitionList ->",
                "                    CompletableFuture.runAsync(() -> " + serviceName + ".getBaseMapper()."
                        + methodName + "(addPartitionList)))",
                "                    .toArray(CompletableFuture[]::new))",
                "            .exceptionally(e -> {",
                "                log.error(\"多线程处理异常\", e);",
                "                return null;",
                "            });",
                "}"));
    }

    /**
     * Prints the mapper interface method declaration for the batch insert.
     *
     * @param clazz      simple name of the entity class
     * @param methodName mapper method name (must equal the XML statement id)
     * @param paramName  {@code @Param} name of the list argument
     */
    public static void printMapper(String clazz, String methodName, String paramName) {
        System.out.println("-----------生成mapper方法-----------");
        System.out.println("void " + methodName + "(@Param(\"" + paramName + "\") List<" + clazz + ">"
                + " " + paramName + ");");
    }

    /**
     * Prints an XML {@code <insert>} statement that inserts every element of {@code paramName}
     * using a {@code <foreach>} over a multi-row VALUES list.
     *
     * <p>Only fields declared directly on {@code clazz} that carry {@code @TableField} (and are not
     * {@code exist = false}) become columns. Superclass fields and a {@code @TableId} field without
     * {@code @TableField} are not included (presumably the id is database-generated — confirm for
     * your entity).
     *
     * @param clazz      entity class; nothing beyond the header is printed when {@code null}
     * @param methodName statement id (must equal the mapper method name)
     * @param paramName  {@code @Param} name of the list argument
     * @throws IllegalArgumentException if {@code clazz} has no non-blank {@code @TableName} value
     *                                  or no insertable {@code @TableField} columns
     */
    public static void printXml(Class<?> clazz, String methodName, String paramName) {
        System.out.println("-----------生成XML语句-----------");
        if (clazz == null) {
            return;
        }
        TableName table = clazz.getAnnotation(TableName.class);
        if (table == null || table.value().trim().isEmpty()) {
            throw new IllegalArgumentException(
                    clazz.getName() + " must declare @TableName with an explicit table name");
        }
        StringBuilder columns = new StringBuilder();
        StringBuilder values = new StringBuilder();
        int columnCount = 0;
        for (Field field : clazz.getDeclaredFields()) {
            TableField tableField = field.getAnnotation(TableField.class);
            // exist = false marks a field that has no backing column; inserting it breaks the SQL.
            if (tableField == null || !tableField.exist()) {
                continue;
            }
            if (columnCount > 0) {
                // Wrap after every COLUMNS_PER_LINE emitted columns to keep the XML readable.
                String separator = columnCount % COLUMNS_PER_LINE == 0 ? ",\n" : ",";
                columns.append(separator);
                values.append(separator);
            }
            columns.append(columnName(field, tableField));
            values.append("#{").append(ITEM_NAME).append('.').append(field.getName()).append('}');
            columnCount++;
        }
        if (columnCount == 0) {
            throw new IllegalArgumentException(clazz.getName() + " has no @TableField columns to insert");
        }
        String result = joinLines(
                "<insert id=\"" + methodName + "\" parameterType=\"java.util.List\">",
                "INSERT INTO " + table.value(),
                "(",
                columns.toString(),
                ")",
                "VALUES",
                "<foreach collection=\"" + paramName + "\" item=\"" + ITEM_NAME + "\" separator=\",\">",
                "(",
                values.toString(),
                ")",
                "</foreach>",
                "</insert>");
        System.out.println(result);
    }

    /**
     * Returns the column name for a {@code @TableField} field: its explicit {@code value}, or,
     * when that is blank, the field name converted from camelCase to snake_case. The fallback
     * assumes MyBatis-Plus's default underscore column naming — adjust if the project disables it.
     */
    private static String columnName(Field field, TableField tableField) {
        String explicit = tableField.value().trim();
        return explicit.isEmpty() ? camelToUnderscore(field.getName()) : explicit;
    }

    /** Converts a camelCase identifier to snake_case, e.g. {@code budgetYear -> budget_year}. */
    private static String camelToUnderscore(String name) {
        StringBuilder out = new StringBuilder(name.length() + 8);
        for (int i = 0; i < name.length(); i++) {
            char c = name.charAt(i);
            if (Character.isUpperCase(c)) {
                if (i > 0) {
                    out.append('_');
                }
                out.append(Character.toLowerCase(c));
            } else {
                out.append(c);
            }
        }
        return out.toString();
    }

    /** Joins generated snippet lines with a single {@code '\n'} so output line endings are consistent. */
    private static String joinLines(String... lines) {
        return String.join("\n", lines);
    }
}